#!/usr/local/bin/python3

import os
import sys
import requests
from hashlib import md5
import urllib.parse as up
import matplotlib.pyplot as plt
from pyquery import PyQuery as pq

HEADERS = {
    'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_5) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/11.1.1 Safari/605.1.15'
}

def get_blog_posts(url, css_selector):
    '''获取博客文章列表
    :param url: 博客的归档页面地址
    :type url: 字符串
    :param css_selector: 用于选择文章条目对应<a>标签的CSS选择器
    :type css_selector: 字符串
    :return archive: 一个包含博客名称和博文列表的字典，博文列表包含标题、链接和日期信息
    :rtype: dict
    '''
    try:
        response = requests.get(url, headers=HEADERS)
        if response.status_code == 200:
            url_p = up.urlparse(url)
            host_url = '{0}://{1}'.format(url_p.scheme, url_p.netloc)
            doc = pq(response.content)
            archive = {
                'name': doc('head title').text().split('|')[1].strip(),
                'posts': []
            }
            for post in doc(css_selector).items():
                article = {
                    'title': post.text(),
                    'url': ''.join([host_url, post.attr.href]),
                    'date': post.attr.href[1:11]
                }
                archive['posts'].append(article)
            return archive
    except requests.ConnectionError:
        print('网络链接异常！请重试！')
        sys.exit(-1)

def statistic_posts_by_date(posts):
    '''按照日期对博文分类统计文章个数
    :param posts: 博文列表，每个元素包含标题、链接和日期信息
    :type posts: list
    :return statistics: 按照年份月份统一文章个数
    :rtype statistics: dict
    '''
    statistics = {}
    months = ['%02d'%x for x in range(1, 13)]
    for post in posts:
        year, month, day = post['date'].split('/')
        if statistics.get(year) == None:
            statistics[year] = {}
            for month_key in months:
                statistics[year][month_key] = 0
        statistics[year][month] += 1
    return statistics

def plot_statistics(statistics, title='posts yearly'):
    '''根据统计数据绘制曲线图
    :param statistics: 博文按照日期的统计数据
    :type statistics: dict
    :param title: 曲线图标题
    :type title: 字符串
    '''
    count = len(statistics)
    plt.close('all')
    plt.figure(figsize=(12, count * 2))
    plt.suptitle(title)
    row = count // 2 + 1
    sub_n = 1
    for year, data in statistics.items():
        plt.subplot(row, 2, sub_n)
        sub_n += 1
        x = range(len(data))
        plt.bar(x, data.values(), label=year)
        for a,b in zip(x, data.values()):
            plt.text(a, b+0.05, '%d' % b, ha='center', va= 'bottom',fontsize=7)
        plt.xticks(x, data.keys())
        plt.xlabel('Month')
        plt.ylabel('Count of posts')
        plt.ylim(0, 15)
        plt.legend()
    plt.show()

def get_post_content(post):
    '''获取指定博文的内容
    :param post: 博文
    :type post: dict
    :return post_content: 访问链接中博文内容部分的PyQuery对象
    :rtype: PyQuery
    '''
    url = post['url']
    try:
        response = requests.get(url, headers=HEADERS)
        if response.status_code == 200:
            doc = pq(response.content)
            return doc('.post')
        else:
            print('异常，状态码：', response.status_code)
    except requests.ConnectionError:
        print('\n异常...', post['title'], url, '\n')

def get_post_images(posts):
    '''获取博文列表中博文中所有图片的链接
    :param posts: 博文列表，每个元素包含博文标题、链接和日期
    :type posts: list
    :return images: image信息，包含所属博文标题、图片链接
    :rtype: dict
    '''
    for post in posts:
        post_content = get_post_content(post)
        if post_content:
            images = post_content.find('img').items()
            url_p = up.urlparse(post['url'])
            host_url = '{0}://{1}'.format(url_p.scheme, url_p.netloc)
            for image in images:
                image_url = image.attr.src
                if image_url[0] == '/':
                    image_url = host_url + image_url
                yield {
                    'title': post['title'],
                    'image': image_url
                }

def save_image_from(image_info, to_dir=''):
    '''根据链接获取图片，保存到标题命名的目录中，图片以md5码命名
    :param image_info: 包含标题和链接信息的字典数据
    :type image_info: dict
    :param to_dir: 要保存到的目录，默认值是空字符串，可指定目录，格式如: '狗狗'
    :type to_dir: unicode 字符串
    '''
    if image_info:
        print(image_info)
        image_dir = to_dir + '/' + image_info['title']
        if not os.path.exists(image_dir):
            os.makedirs(image_dir)
        try:
            response = requests.get(image_info['image'])
            if response.status_code == 200:
                image_path = '{0}/{1}.{2}'.format(image_dir, md5(response.content).hexdigest(), 'jpg')
                print(image_path)
                if not os.path.exists(image_path):
                    with open(image_path, 'wb') as f:
                        f.write(response.content)
                    print('图片下载完成！')
                else:
                    print('图片已下载 -> ', image_path)
        except requests.ConnectionError as e:
            print('图片下载失败！')
        except:
            print('出现异常！')

def parse_sys_args(args=sys.argv):
    '''解析命令参数
    '''
    help_info = '''\n    使用方法：
      %s 子命令 hexo博客归档地址 CSS选择器
      - 子命令：count 和 pic，count用于得到指定博客中每个月份博文数，pic用于爬取博文中的图片
      - hexo博客归档地址：比如http://www.litreily.top/archives/
      - CSS选择器：用于选择文章列表中文章标题对应的<a>标签
    举例：
      %s count http://www.litreily.top/archives/ '.post-archive .listing li a'
      %s pic http://www.smslit.top/archives/ '#posts article .post-title-link'

    ''' % (args[0], args[0], args[0])
    if len(args) != 4:
        print(help_info)
        sys.exit(-1)
    else:
        if args[1] != 'pic' and args[1] != 'count':
            print('请输入正确子命令！')
            print(help_info)
            sys.exit(-1)
        return (args[1], args[2], args[3])

if __name__ == '__main__':
    command, url, css_selector = parse_sys_args()
    archive = get_blog_posts(url, css_selector)
    print(archive['name'], '一共有', len(archive['posts']), '篇博文！')
    if len(archive['posts']) > 0:
        if command == 'count':
            plot_statistics(statistic_posts_by_date(archive['posts']), archive['name']+'\'s posts yearly')
        else:
            for image in get_post_images(archive['posts']):
                save_image_from(image, archive['name'])
    else:
        print('没有发现博文，请确认您写的css选择器是否合理！')

